b42fdc
@@ -481,6 +481,12 @@
private void semijoinRemovalBasedTransformations(OptimizeTezProcContext procCtx,
       markSemiJoinForDPP(procCtx);
       perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER, "Mark certain semijoin edges important based ");
 
+      // Remove any semi join edges from Union Op
+      perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER);
+      removeSemiJoinEdgesForUnion(procCtx);
+      perfLogger.PerfLogEnd(this.getClass().getName(), PerfLogger.TEZ_COMPILER,
+                            "Remove any semi join edge between Union and RS");
+
       // Remove any parallel edge between semijoin and mapjoin.
       perfLogger.PerfLogBegin(this.getClass().getName(), PerfLogger.TEZ_COMPILER);
       removeSemijoinsParallelToMapJoin(procCtx);
@@ -1313,6 +1319,56 @@
private boolean findParallelSemiJoinBranch(Operator<?> mapjoin, TableScanOperato
     return parallelEdges;
   }
 
+  /*
+   * Recursively visits op and all of its descendants (the walk does not stop at intermediate
+   * RS operators) and, for every childless ReduceSinkOperator found in the parse context's
+   * RS-to-semijoin map, records RS -> getTsOp() in sjToRemove. Nothing is removed here.
+   */
+  private void removeSemiJoinEdges(Operator<?> op, OptimizeTezProcContext procCtx,
+                                   Map<ReduceSinkOperator, TableScanOperator> sjToRemove) throws SemanticException {
+    boolean isLeafRs = op instanceof ReduceSinkOperator && op.getNumChild() == 0;
+    SemiJoinBranchInfo sjInfo = isLeafRs ? procCtx.parseContext.getRsToSemiJoinBranchInfo().get(op) : null;
+    if (sjInfo != null) {
+      sjToRemove.put((ReduceSinkOperator) op, sjInfo.getTsOp());
+    }
+    for (Operator<?> child : op.getChildOperators()) {
+      removeSemiJoinEdges(child, procCtx, sjToRemove);
+    }
+  }
+
+  /*
+   * Finds every UnionOperator reachable from the top (TS) operators and removes the semijoin
+   * branches that removeSemiJoinEdges collects below each of them.
+   */
+  private void removeSemiJoinEdgesForUnion(OptimizeTezProcContext procCtx) throws SemanticException {
+    // One shared visited set: operators reachable from several TS ops or via several paths
+    // are expanded once instead of once per path, and each union is processed only once.
+    Set<Operator<?>> visited = new HashSet<>();
+    Map<ReduceSinkOperator, TableScanOperator> sjToRemove = new HashMap<>();
+    Deque<Operator<?>> deque = new LinkedList<>(procCtx.parseContext.getTopOps().values());
+    while (!deque.isEmpty()) {
+      Operator<?> op = deque.pollLast();
+      if (!visited.add(op)) {
+        continue;
+      }
+      if (op instanceof UnionOperator) {
+        removeSemiJoinEdges(op, procCtx, sjToRemove);
+      }
+      deque.addAll(op.getChildOperators());
+    }
+    // Removal happens only after the walk above has finished over the unmodified graph.
+    for (Map.Entry<ReduceSinkOperator, TableScanOperator> entry : sjToRemove.entrySet()) {
+      ReduceSinkOperator rs = entry.getKey();
+      TableScanOperator ts = entry.getValue();
+      if (LOG.isDebugEnabled()) {
+        LOG.debug("Semijoin optimization with Union operator. Removing semijoin "
+                      + OperatorUtils.getOpNamePretty(rs) + " - " + OperatorUtils.getOpNamePretty(ts));
+      }
+      GenTezUtils.removeBranch(rs);
+      GenTezUtils.removeSemiJoinOperator(procCtx.parseContext, rs, ts);
+    }
+  }
+
   /*
    *  The algorithm looks at all the mapjoins in the operator pipeline until
    *  it hits RS Op and for each mapjoin examines if it has paralllel semijoin
